from mynumpy import *
import sample
import map
from matplotlib import pyplot as plt
import seaborn as sns

def start_figure(mysize=(5,5),bot=.14,top=.93):
    """Create a new figure with seaborn 'white' styling and fixed margins.

    Args:
        mysize: figure size, passed to ``plt.figure(figsize=...)``.
        bot: bottom edge of the subplot area (``subplots_adjust`` ``bottom``).
        top: top edge of the subplot area (``subplots_adjust`` ``top``).

    Returns:
        The newly created figure. The left and right edges of the subplot
        area are always 0.14 and 0.93.
    """
    fig = plt.figure(figsize=mysize)
    # Styling is applied after the figure is created, exactly as before, so
    # it affects the axes that are added to the figure later on.
    sns.set(font_scale=2)
    sns.set_style('white')
    # Set all four margins in one call so they are validated against each
    # other rather than one at a time against the figure's previous values
    # (setting only ``bottom`` first is rejected when it is not below the
    # figure's current ``top``).
    fig.subplots_adjust(left=0.14, right=0.93, bottom=bot, top=top)
    return fig

def nice_ticks(fig):
    """Restrict the major ticks of both axes of ``fig.gca()`` to integers.

    Args:
        fig: figure whose current axes (``fig.gca()``) are modified in place.
    """
    # Import from the module that defines the locator instead of going
    # through the catch-all ``pylab`` namespace.
    from matplotlib.ticker import MaxNLocator

    ax = fig.gca()
    # Each axis gets its own locator instance.
    ax.get_yaxis().set_major_locator(MaxNLocator(integer=True))
    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))

def makeplot(sampler,tform,ndims,nsamps,label,A,mu,bins=None,extralabel=''):
    """Plot base samples and their transformed images, saving four PDFs.

    Two figures are drawn. The first shows the points returned by
    ``sampler(ndims, nsamps)`` on the unit square; the second shows the
    result of ``tform[0](omega, A, mu)`` applied to those points. Each figure
    is saved twice: once bare (``*_noyaxis.pdf``) and once more after small
    labelled axis arrows have been added.

    Args:
        sampler: callable invoked as ``sampler(ndims, nsamps)``. Rows 0 and 1
            of its result are used as the two plotted coordinates, one
            column per point.
        tform: sequence whose item 0 is a callable ``f(points, A, mu)`` and
            whose item 1 is a string used as the file-name prefix.
        ndims: forwarded unchanged to ``sampler``.
        nsamps: forwarded unchanged to ``sampler``.
        label: sampler name; used in the first figure's title and in all
            output file names.
        A: matrix passed to the transform and used to draw the contour.
        mu: offset passed to the transform and used to draw the contour.
        bins: number of grid cells per axis. ``None`` draws no grid; the
            transformed grid in the second figure is drawn only if
            ``bins > 1``.
        extralabel: currently unused; accepted so existing callers keep
            working.

    Side effects:
        Writes ``<prefix>_<label>_omega_noyaxis.pdf``,
        ``<prefix>_<label>_omega.pdf``, ``<prefix>_<label>_z_noyaxis.pdf``
        and ``<prefix>_<label>_z.pdf`` to the current directory, where
        ``<prefix>`` is ``tform[1]``. Both figures are closed afterwards.
    """
    ms = 20  # marker size handed to plt.scatter

    omega = sampler(ndims,nsamps)
    z = tform[0](omega,A,mu)

    # One colour per point, running from blue (first point) to red (last).
    # The ramp is indexed by integers so it always has exactly n_colors
    # entries; a float-step arange can produce one entry too many.
    # NOTE(review): assumes omega has as many columns as z -- confirm.
    n_colors = z.shape[1]
    fracs = [i / n_colors for i in range(n_colors)]
    cols = [(f*.8,0,(1-f)*.8) for f in fracs]

    # ---- figure 1: base samples on the unit square ----
    fig=start_figure()
    plt.scatter(omega[0,:],omega[1,:],ms,color=cols)
    plt.xlabel('$u_1$')
    plt.ylabel('$u_2$')

    plt.axis('equal')
    plt.xlim([0-.01,1+.01])
    plt.ylim([0-.01,1+.01])

    if bins is not None:
        # The small epsilon on the stop value keeps the closing line at 1.
        for x in arange(0,1+1e-5,1/bins):
            plt.plot([x,x],[0,1],'k-',alpha=0.2,linewidth=.75)
            plt.plot([0,1],[x,x],'k-',alpha=0.2,linewidth=.75)

    nice_ticks(fig)
    plt.axis('off')

    plt.title(label+' base samples')

    # Save the bare version first, then add the axis arrows and save again.
    plt.savefig(tform[1]+'_'+label+'_omega_noyaxis.pdf',transparent=True)
    plt.arrow(0,0,.1,.0,head_width=.1/4.5)
    plt.text(.01,-.1,'$u_1$')
    plt.arrow(0,0,0,.1,head_width=.1/4.5)
    plt.text(-.15,.01,'$u_2$')

    plt.savefig(tform[1]+'_'+label+'_omega.pdf',transparent=True)

    plt.close()

    # ---- figure 2: transformed samples ----
    fig = start_figure((5,3.5),bot=.17,top=1.0)

    plt.scatter(z[0,:],z[1,:],ms,color=cols)

    # Contour: the circle of radius r mapped through A and shifted by mu.
    theta = arange(0,2*pi,.01)
    for r in [2]:
        cont  = r * A @ vstack([cos(theta),sin(theta)]) + expand_dims(mu,axis=1)
        plt.plot(cont[0,:],cont[1,:],color='darkgreen',alpha=0.5,linewidth=1.5)

    # Push the unit-square grid lines through the same transform.
    zo = arange(0,1+1e-5,.01)
    if bins is not None and bins > 1:
        for x in arange(0,1,1/bins):
            # line of constant first coordinate
            start = vstack([x+0*zo,zo])
            end   = tform[0](start,A,mu)
            plt.plot(end[0,:],end[1,:],'k-',alpha=0.2,linewidth=.75)
            # line of constant second coordinate
            start = vstack([zo,x+0*zo])
            end   = tform[0](start,A,mu)
            plt.plot(end[0,:],end[1,:],'k-',alpha=0.2,linewidth=.75)

    # Hard-coded view window so every plot shares the same extent.
    # NOTE(review): the limits look hand-tuned from an equal-aspect plot --
    # confirm before changing the figure size or margins.
    plt.axis([-1.75,1.75,-1.55,-1.55+2*1.75*.6666])
    nice_ticks(fig)

    sns.despine(left=True, bottom=True, right=True)
    plt.xlabel('$z_1$')
    plt.ylabel('$z_2$')
    plt.axis('off')

    # Same two-step save as for the first figure.
    plt.savefig(tform[1]+'_'+label+'_z_noyaxis.pdf',transparent=True)
    plt.arrow(-.7,-1.2,.5,.0,head_width=.1)
    plt.text(-1,-1.25,'$z_1$')
    plt.arrow(-1.1,-1.0,0,0.5,head_width=.1)
    plt.text(-1.45,-.95,'$z_2$')
    plt.savefig(tform[1]+'_'+label+'_z.pdf',transparent=True)
    plt.close()

# Problem size: two dimensions, 5 x 5 points per plot.
ndims = 2
nsamps = 5**2  # other sizes tried: 49, 2**2, 4

# Matrix and offset handed to every transform.
# Earlier choices for A: [[.5,.1],[-.25,.4]] and [[.7,.1],[-.25,.4]].
A = array([[.7, .1], [-.25, .3]])
mu = array([.2, -.35])

# (transform function, file-name prefix, extra label)
tforms = [
    (map.circ, 'circ', ' circular'),
    (map.icdf, 'icdf', ' cartesian'),
]

for tform in tforms:
    # Re-seed so every transform starts from the same random stream.
    seed(3)

    _per_axis = nsamps**(1/ndims)
    # (sampler, label, bins), run in this order.
    #
    # Antithetic sampling works with sample.anti, but without the nice
    # symmetry between starting distributions: the starting samples would
    # have to be reflected differently for each transform. A previous
    # variant used sample.anti_both for 'icdf' and sample.anti_1st
    # otherwise. bins=2 is a hack to show the reflection axes.
    _runs = [
        (sample.iid, 'iid', 1),
        (sample.stratified, 'stratified', _per_axis),
        (sample.qmc, 'qmc', _per_axis),
        (sample.anti, 'antithetic', 2),
        (sample.latin, 'latin', nsamps),
    ]
    for _sampler, _label, _bins in _runs:
        makeplot(_sampler, tform, ndims, nsamps, _label, A, mu,
                 bins=_bins, extralabel=tform[2])

# todo: strat + QMC
# todo: apply to radius only